{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "SeFa",
      "provenance": [],
      "collapsed_sections": [],
      "toc_visible": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qJDJLE3v0HNr"
      },
      "source": [
        "# Fetch Codebase"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "JqiWKjpFa0ov",
        "cellView": "form"
      },
      "source": [
        "#@title\n",
        "import os\n",
        "\n",
        "# Work from Colab's default workspace so the relative paths below resolve.\n",
        "os.chdir('/content')\n",
        "CODE_DIR = 'sefa'\n",
        "# Clone only when the checkout is missing, so re-running this cell does not\n",
        "# print git's 'destination path already exists' error.\n",
        "if not os.path.isdir(CODE_DIR):\n",
        "  !git clone https://github.com/genforce/sefa.git $CODE_DIR\n",
        "# Later cells import `models` and `utils` by bare name (presumably provided by\n",
        "# the cloned repository), so its root becomes the working directory.\n",
        "os.chdir(CODE_DIR)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hQ_IXBZr8YcJ"
      },
      "source": [
        "# Define Utility Functions"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ijKTlG5GeTd3",
        "cellView": "form"
      },
      "source": [
        "#@title\n",
        "import os.path\n",
        "import io\n",
        "import IPython.display\n",
        "import numpy as np\n",
        "import cv2\n",
        "import PIL.Image\n",
        "\n",
        "import torch\n",
        "\n",
        "from models import parse_gan_type\n",
        "from utils import to_tensor\n",
        "from utils import postprocess\n",
        "from utils import load_generator\n",
        "from utils import factorize_weight\n",
        "\n",
        "\n",
        "def sample(generator, gan_type, num=1, seed=0):\n",
        "  \"\"\"Draws seeded latent codes and pre-processes them for the given GAN type.\n",
        "\n",
        "  Args:\n",
        "    generator: Generator exposing `z_space_dim` plus, depending on\n",
        "      `gan_type`, `layer0.pixel_norm` or `mapping` / `truncation`.\n",
        "    gan_type: 'pggan', 'stylegan' or 'stylegan2'. Any other value leaves the\n",
        "      Gaussian noise unprocessed.\n",
        "    num: Number of codes to draw.\n",
        "    seed: Value passed to `torch.manual_seed` before drawing.\n",
        "\n",
        "  Returns:\n",
        "    The codes as a NumPy array.\n",
        "  \"\"\"\n",
        "  torch.manual_seed(seed)\n",
        "  # The noise is moved to the GPU, so a CUDA device is required.\n",
        "  latents = torch.randn(num, generator.z_space_dim).cuda()\n",
        "\n",
        "  if gan_type == 'pggan':\n",
        "    latents = generator.layer0.pixel_norm(latents)\n",
        "  elif gan_type in ('stylegan', 'stylegan2'):\n",
        "    # Both variants map z -> w and then truncate; only the truncation\n",
        "    # hyper-parameters differ.\n",
        "    psi, num_trunc_layers = (0.7, 8) if gan_type == 'stylegan' else (0.5, 18)\n",
        "    w_codes = generator.mapping(latents)['w']\n",
        "    latents = generator.truncation(\n",
        "        w_codes, trunc_psi=psi, trunc_layers=num_trunc_layers)\n",
        "\n",
        "  return latents.detach().cpu().numpy()\n",
        "\n",
        "\n",
        "def synthesize(generator, gan_type, codes):\n",
        "  \"\"\"Synthesizes images with the given codes.\n",
        "\n",
        "  Args:\n",
        "    generator: Generator network. For 'pggan' it is called directly; for\n",
        "      'stylegan' / 'stylegan2' only its `synthesis` sub-network is run.\n",
        "    gan_type: One of 'pggan', 'stylegan' or 'stylegan2'.\n",
        "    codes: Latent codes; converted with `to_tensor` before being fed to the\n",
        "      network.\n",
        "\n",
        "  Returns:\n",
        "    The result of `postprocess` applied to the network's 'image' output.\n",
        "\n",
        "  Raises:\n",
        "    ValueError: If `gan_type` is not one of the supported types.\n",
        "  \"\"\"\n",
        "  if gan_type == 'pggan':\n",
        "    network = generator\n",
        "  elif gan_type in ['stylegan', 'stylegan2']:\n",
        "    network = generator.synthesis\n",
        "  else:\n",
        "    # Fail with a clear message instead of an UnboundLocalError further down.\n",
        "    raise ValueError(f'Unsupported GAN type: {gan_type!r}')\n",
        "\n",
        "  # Inference only: skip autograd bookkeeping so no graph is kept in memory.\n",
        "  with torch.no_grad():\n",
        "    images = network(to_tensor(codes))['image']\n",
        "  return postprocess(images)\n",
        "\n",
        "\n",
        "def imshow(images, col, viz_size=256):\n",
        "  \"\"\"Tiles a batch of images into one grid and displays it inline as a JPEG.\n",
        "\n",
        "  Args:\n",
        "    images: Array of shape (num, height, width, channels); `num` must be a\n",
        "      multiple of `col`.\n",
        "    col: Number of tiles per row.\n",
        "    viz_size: Side length in pixels of every tile; images of any other size\n",
        "      are resized to `viz_size` x `viz_size`.\n",
        "\n",
        "  Returns:\n",
        "    Whatever `IPython.display.display` returns for the rendered image.\n",
        "  \"\"\"\n",
        "  num, height, width, channels = images.shape\n",
        "  assert num % col == 0\n",
        "  row = num // col\n",
        "  # All images in the batch share one size, so decide once whether to resize.\n",
        "  needs_resize = height != viz_size or width != viz_size\n",
        "\n",
        "  canvas = np.zeros((viz_size * row, viz_size * col, channels), dtype=np.uint8)\n",
        "  for idx, tile in enumerate(images):\n",
        "    top = (idx // col) * viz_size\n",
        "    left = (idx % col) * viz_size\n",
        "    if needs_resize:\n",
        "      tile = cv2.resize(tile, (viz_size, viz_size))\n",
        "    canvas[top:top + viz_size, left:left + viz_size] = tile\n",
        "\n",
        "  # Encode in memory so nothing is written to disk.\n",
        "  buffer = io.BytesIO()\n",
        "  PIL.Image.fromarray(canvas).save(buffer, 'jpeg')\n",
        "  return IPython.display.display(IPython.display.Image(buffer.getvalue()))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Q7gkmrVW8eR1"
      },
      "source": [
        "# Select a Model"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "NoWI4fPQ6Gnf"
      },
      "source": [
        "#@title { display-mode: \"form\", run: \"auto\" }\n",
        "model_name = \"stylegan_animeface512\" #@param ['stylegan_animeface512', 'stylegan_car512', 'stylegan_cat256', 'pggan_celebahq1024', 'stylegan_bedroom256']\n",
        "\n",
        "# Load the generator chosen in the form and record its GAN type; the type\n",
        "# decides how latent codes are sampled and edited in the cells below.\n",
        "# NOTE: `codes` from the sampling cell is computed from `generator`, so re-run\n",
        "# that cell after switching models.\n",
        "generator = load_generator(model_name)\n",
        "gan_type = parse_gan_type(generator)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zDStH1O5t1KC"
      },
      "source": [
        "# Sample Latent Codes"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "qlRGKZbJt9hA"
      },
      "source": [
        "#@title { display-mode: \"form\", run: \"auto\" }\n",
        "\n",
        "num_samples = 3 #@param {type:\"slider\", min:1, max:8, step:1}\n",
        "noise_seed = 0 #@param {type:\"slider\", min:0, max:1000, step:1}\n",
        "\n",
        "# `codes` and `num_samples` are reused by the editing cell, so keep their names.\n",
        "codes = sample(generator, gan_type, num=num_samples, seed=noise_seed)\n",
        "images = synthesize(generator, gan_type, codes)\n",
        "# One row holding every sample.\n",
        "imshow(images, col=num_samples)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MmRPN3xz8jCH"
      },
      "source": [
        "# Factorize & Edit"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ccONBF60mVir"
      },
      "source": [
        "#@title { display-mode: \"form\", run: \"auto\" }\n",
        "\n",
        "layer_idx = \"0-1\" #@param ['all', '0-1', '2-5', '6-13']\n",
        "semantic_1 = 0 #@param {type:\"slider\", min:-3.0, max:3.0, step:0.1}\n",
        "semantic_2 = 0 #@param {type:\"slider\", min:-3.0, max:3.0, step:0.1}\n",
        "semantic_3 = 0 #@param {type:\"slider\", min:-3.0, max:3.0, step:0.1}\n",
        "semantic_4 = 0 #@param {type:\"slider\", min:-3.0, max:3.0, step:0.1}\n",
        "semantic_5 = 0 #@param {type:\"slider\", min:-3.0, max:3.0, step:0.1}\n",
        "\n",
        "# Fast implementation to factorize the weight by SeFa.\n",
        "layers, boundaries, _ = factorize_weight(generator, layer_idx)\n",
        "\n",
        "# Gather the slider values explicitly rather than resolving the variable\n",
        "# names through eval(); the i-th value scales the i-th boundary.\n",
        "semantic_steps = [semantic_1, semantic_2, semantic_3, semantic_4, semantic_5]\n",
        "\n",
        "# Edit a copy so `codes` from the sampling cell is left untouched and every\n",
        "# re-run of this cell starts from the same unedited codes.\n",
        "new_codes = codes.copy()\n",
        "for sem_idx, step in enumerate(semantic_steps):\n",
        "  boundary = boundaries[sem_idx:sem_idx + 1]\n",
        "  if gan_type == 'pggan':\n",
        "    new_codes += boundary * step\n",
        "  elif gan_type in ['stylegan', 'stylegan2']:\n",
        "    # Only the selected layers receive the shift.\n",
        "    new_codes[:, layers, :] += boundary * step\n",
        "new_images = synthesize(generator, gan_type, new_codes)\n",
        "imshow(new_images, col=num_samples)"
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}